package gorm

import (
	"context"
	"errors"
	"fmt"
	"sort"
	"sync"

	"gitee.com/zacyuan/yuan/pkg/config"
	"gitee.com/zacyuan/yuan/pkg/database/gorm/serializer"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

// configNode is the configuration key whose children describe the databases
// that Start reads and opens.
const configNode = "databases"

// dbs maps a connection name (string key) to its *DB handle. Open stores
// entries and GetDBByNode looks them up.
var dbs sync.Map

// once guards the one-time initialization performed by Start.
var once sync.Once

func GetDefault(ctx ...context.Context) *DB {
	return GetDBByNode("default", ctx...)
}

// GetDBByNode returns the connection registered under node, or nil when no
// connection has been stored under that name.
//
// If at least one context is given, the first one is attached via WithContext
// and that derived handle is returned; any additional contexts are ignored.
func GetDBByNode(node string, ctx ...context.Context) *DB {
	value, found := dbs.Load(node)
	if !found {
		return nil
	}

	db := value.(*DB)
	if len(ctx) == 0 {
		return db
	}
	return db.WithContext(ctx[0])
}

// Open builds a GORM connection from cfg, configures its connection pool
// (MaxIdleConns / MaxOpenConns), registers cfg.Interceptors as hooks, and
// stores the handle under cfg.Name so it can later be fetched with
// GetDBByNode.
//
// Supported drivers are "mysql" (also used when cfg.Driver is empty),
// "sqlite" and "postgres"; any other value is rejected with an error.
// If cfg.GormConfig is nil, an empty gorm.Config is assigned to it. When
// cfg.WriteLog is set, cfg.GormConfig.Logger is replaced by a GormLogger.
//
// Opening a name that is already registered replaces the stored handle; the
// previously stored handle is not closed.
func Open(cfg *Config) (*DB, error) {
	if cfg == nil {
		return nil, errors.New("config is nil")
	}
	if cfg.Name == "" {
		return nil, errors.New("name is empty")
	}

	var dial gorm.Dialector
	switch cfg.Driver {
	case "", "mysql":
		dial = mysql.Open(cfg.DataSource)
	case "sqlite":
		dial = sqlite.Open(cfg.DataSource)
	case "postgres":
		dial = postgres.Open(cfg.DataSource)
	default:
		// Without this, gorm.Open would run with a nil dialector and the
		// failure would only surface later as an opaque "invalid db" error.
		return nil, fmt.Errorf("database %q: unsupported driver %q", cfg.Name, cfg.Driver)
	}

	// gorm.Open applies the *gorm.Config option by dereferencing it, and the
	// logger assignment below does too, so a nil config would panic.
	if cfg.GormConfig == nil {
		cfg.GormConfig = &gorm.Config{}
	}
	if cfg.WriteLog {
		cfg.GormConfig.Logger = &GormLogger{}
	}

	db, err := gorm.Open(dial, cfg.GormConfig)
	if err != nil {
		return nil, fmt.Errorf("open database %q: %w", cfg.Name, err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("database %q: get connection pool: %w", cfg.Name, err)
	}
	sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
	sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)

	if err := RegisterHook(db, cfg.Interceptors...); err != nil {
		// The handle was never published, so release its pool here instead
		// of leaking the open connections. The close error is secondary to
		// the hook error being returned.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("database %q: register hooks: %w", cfg.Name, err)
	}

	dbs.Store(cfg.Name, db)

	return db, nil
}

func Start() error {
	var err error
	once.Do(func() {
		schema.RegisterSerializer("date", serializer.DateSerializer{})

		list := make(map[string]any)
		err = config.UnmarshalKey(configNode, &list)
		if err != nil {
			return
		}

		for one := range list {
			key := configNode + "." + one
			cfg := DefaultConfig()
			err = config.UnmarshalKey(key, cfg)
			if err != nil {
				return
			}

			if cfg.Startup {
				if cfg.Name == "" {
					cfg.Name = one
				}
				_, err = Open(cfg)
				if err != nil {
					return
				}
			}
		}
	})

	return err
}
